package code;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;


public class TreeFlatten {

    /** Minimal binary tree node: an int payload plus left/right child links. */
    public static class TreeNode {
        int val;
        TreeNode left;
        TreeNode right;

        TreeNode(int x) {
            val = x;
        }
    }

    /**
     * Appends the nodes of the tree rooted at {@code p} to {@code list} in preorder
     * (node, then left subtree, then right subtree).
     *
     * <p>Uses an explicit stack instead of recursion, so very deep (e.g. fully skewed)
     * trees do not overflow the call stack.
     *
     * @param p    root of the subtree to traverse; {@code null} appends nothing
     * @param list destination list the visited nodes are appended to
     */
    public void dfs(TreeNode p, List<TreeNode> list) {
        if (p == null) {
            return;
        }
        // ArrayDeque rejects null elements, so only non-null children are pushed.
        Deque<TreeNode> stack = new ArrayDeque<TreeNode>();
        stack.push(p);
        while (!stack.isEmpty()) {
            TreeNode node = stack.pop();
            list.add(node);
            // Push right before left so the left subtree is popped (visited) first.
            if (node.right != null) {
                stack.push(node.right);
            }
            if (node.left != null) {
                stack.push(node.left);
            }
        }
    }

    /**
     * Flattens the tree rooted at {@code root}, in place, into a right-linked list.
     * The list follows preorder, and every node's {@code left} becomes {@code null}.
     *
     * <p>Runs in O(n) time with O(1) extra space. No auxiliary list and no recursion
     * are used.
     *
     * @param root root of the tree to flatten; {@code null} is a no-op
     */
    public void flatten(TreeNode root) {
        TreeNode cur = root;
        while (cur != null) {
            if (cur.left != null) {
                // In preorder, the last node of the left subtree is its right-most node.
                // That node is followed directly by cur's original right subtree.
                TreeNode tail = cur.left;
                while (tail.right != null) {
                    tail = tail.right;
                }
                // Splice the left subtree between cur and its old right subtree.
                tail.right = cur.right;
                cur.right = cur.left;
                cur.left = null;
            }
            cur = cur.right;
        }
    }

    /** Entry point; intentionally empty. */
    public static void main(String[] args) {
    }
}
